# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Tests of end-to-end IREE support for individual ops in the TOSA dialect.
# Each test file should have a name matching the corresponding TOSA op and test only the
# functionality of that op (though may make use of other ops where necessary). Tests should be
# written using the IREE Check framework.
# See https://iree.dev/developers/general/testing-guide/#iree-core-end-to-end-e2e-tests.

load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")

# Package-level settings that apply to every target declared in this file.
package(
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

# Sources for the `check_llvm-cpu_local-task` suite: an explicit file list
# paired with a `*.mlir` glob and an exclude list.
# NOTE(review): enforce_glob presumably fails when the explicit list and the
# glob (minus `exclude`) disagree - confirm in
# //build_tools/bazel:enforce_glob.bzl.
# argmax.mlir and index.mlir have dedicated llvm-cpu suites further down;
# large_linalg_matmul.mlir is only listed in CUDA_SRCS.
LLVM_SRCS = enforce_glob(
    # keep sorted
    [
        "conv2d.mlir",
        "fp4_f32_conversion.mlir",
        "fp_to_subbyte.mlir",
        "gather_like_ops.mlir",
        "narrow_n_matmuls.mlir",
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
        "pack_i8.mlir",
        "softmax.mlir",
        "subbyte_to_fp.mlir",
        "unpack.mlir",
    ],
    include = ["*.mlir"],
    exclude = [
        "argmax.mlir",
        "index.mlir",
        "large_linalg_matmul.mlir",
    ],
)

# Check-test suite over LLVM_SRCS using the `llvm-cpu` target backend and the
# `local-task` driver, targeting a generic CPU.
iree_check_single_backend_test_suite(
    name = "check_llvm-cpu_local-task",
    srcs = LLVM_SRCS,
    compiler_flags = ["--iree-llvmcpu-target-cpu=generic"],
    driver = "local-task",
    tags = [
        # Sub-byte support for wasm is not a priority.
        "nowasm",
    ],
    target_backend = "llvm-cpu",
)

# Runs narrow_n_matmuls.mlir on llvm-cpu with the
# `--iree-dispatch-creation-data-tiling` flag added.
# TODO(#19378): Delete the test suite after we switch to data-tiling fusion for
# tests/e2e/matmul test suite.
iree_check_single_backend_test_suite(
    name = "check_llvm-cpu_dt_fusion_local-task",
    srcs = ["narrow_n_matmuls.mlir"],
    compiler_flags = [
        "--iree-dispatch-creation-data-tiling",
        "--iree-llvmcpu-target-cpu=generic",
    ],
    driver = "local-task",
    tags = [
        # Sub-byte support for wasm is not a priority.
        # NOTE(review): this rationale mirrors the main llvm-cpu suite; confirm
        # it also applies to narrow_n_matmuls.mlir.
        "nowasm",
    ],
    target_backend = "llvm-cpu",
)

# Sources for the `check_vmvx_local-task` suite. Compared with LLVM_SRCS, the
# fp4 and sub-byte conversion tests (fp4_f32_conversion, fp_to_subbyte,
# subbyte_to_fp) are additionally excluded.
VMVX_SRCS = enforce_glob(
    # keep sorted
    [
        "conv2d.mlir",
        "gather_like_ops.mlir",
        "narrow_n_matmuls.mlir",
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
        "pack_i8.mlir",
        "softmax.mlir",
        "unpack.mlir",
    ],
    include = ["*.mlir"],
    exclude = [
        "argmax.mlir",
        "fp_to_subbyte.mlir",
        "fp4_f32_conversion.mlir",
        "index.mlir",
        "large_linalg_matmul.mlir",
        "subbyte_to_fp.mlir",
    ],
)

# Check-test suite over VMVX_SRCS using the `vmvx` target backend and the
# `local-task` driver, with no extra compiler flags.
iree_check_single_backend_test_suite(
    name = "check_vmvx_local-task",
    srcs = VMVX_SRCS,
    driver = "local-task",
    target_backend = "vmvx",
)

# vmvx suite with microkernels enabled; covers only the pack/unpack tests.
iree_check_single_backend_test_suite(
    name = "check_vmvx_ukernel_local-task",
    srcs = [
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
        "unpack.mlir",
    ],
    compiler_flags = [
        "--iree-vmvx-enable-microkernels",
        # Some testcases have linalg.generic ops with multiple ops in the body.
        # If we don't opt out from it, DecomposeLinalgGenericPass splits those
        # into smaller linalg.generic ops with only one op in the body. This
        # results in the creation of temporary buffers between these split
        # linalg.generic ops, causing:
        # > error: failed to legalize operation 'memref.alloca' that was explicitly marked illegal
        "--iree-vmvx-enable-ukernels-decompose-linalg-generic=false",
    ],
    driver = "local-task",
    target_backend = "vmvx",
)

# Sources for the `check_vulkan-spirv_vulkan` suite. All pack/unpack tests as
# well as fp_to_subbyte and fp4_f32_conversion are excluded for this backend.
VULKAN_SRCS = enforce_glob(
    # keep sorted
    [
        "conv2d.mlir",
        "gather_like_ops.mlir",
        "narrow_n_matmuls.mlir",
        "softmax.mlir",
        "subbyte_to_fp.mlir",
    ],
    include = ["*.mlir"],
    exclude = [
        "argmax.mlir",
        "fp_to_subbyte.mlir",
        "fp4_f32_conversion.mlir",
        "index.mlir",
        "large_linalg_matmul.mlir",
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
        "pack_i8.mlir",
        "unpack.mlir",
    ],
)

# Check-test suite over VULKAN_SRCS using the `vulkan-spirv` target backend
# and the `vulkan` driver.
iree_check_single_backend_test_suite(
    name = "check_vulkan-spirv_vulkan",
    srcs = VULKAN_SRCS,
    driver = "vulkan",
    target_backend = "vulkan-spirv",
)

# The Winograd suites only exercise the conv2d tests.
WINOGRAD_CONV_SRCS = ["conv2d.mlir"]

# Preprocessing pipeline flag shared by both Winograd suites below; it requests
# the iree-linalg-ext-convert-conv2d-to-winograd pass on each func.func.
WINOGRAD_PREPROCESSING_FLAGS = [
    "--iree-preprocessing-pass-pipeline=builtin.module\\(func.func\\(iree-linalg-ext-convert-conv2d-to-winograd\\)\\)",
]

# Winograd conv2d tests on llvm-cpu / local-task, targeting a generic CPU.
iree_check_single_backend_test_suite(
    name = "check_winograd_llvm-cpu_local-task",
    srcs = WINOGRAD_CONV_SRCS,
    compiler_flags = WINOGRAD_PREPROCESSING_FLAGS + [
        "--iree-llvmcpu-target-cpu=generic",
    ],
    driver = "local-task",
    target_backend = "llvm-cpu",
)

# Winograd conv2d tests on vulkan-spirv / vulkan.
iree_check_single_backend_test_suite(
    name = "check_winograd_vulkan-spirv_vulkan",
    srcs = WINOGRAD_CONV_SRCS,
    compiler_flags = WINOGRAD_PREPROCESSING_FLAGS,
    driver = "vulkan",
    target_backend = "vulkan-spirv",
)

# Sources for the CUDA suite below: an explicit file list paired with a
# `*.mlir` glob and an exclude list.
# NOTE(review): enforce_glob presumably fails when the explicit list and the
# glob (minus `exclude`) disagree - confirm in
# //build_tools/bazel:enforce_glob.bzl.
CUDA_SRCS = enforce_glob(
    # keep sorted
    [
        "conv2d.mlir",
        "fp_to_subbyte.mlir",
        # Currently only enabled on cuda as it can be slow on other backends.
        "large_linalg_matmul.mlir",
        "narrow_n_matmuls.mlir",
        "pack_i8.mlir",
        "softmax.mlir",
        "subbyte_to_fp.mlir",
        "unpack.mlir",
    ],
    include = ["*.mlir"],
    exclude = [
        "argmax.mlir",
        "fp4_f32_conversion.mlir",
        "gather_like_ops.mlir",
        "index.mlir",
        # Excluded due to https://github.com/llvm/llvm-project/issues/131386;
        # see bug #20294.
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
    ],
)

# Check-test suite over CUDA_SRCS using the `cuda` target backend and driver.
# NOTE(review): despite its name, this suite covers every file in CUDA_SRCS,
# not only large_linalg_matmul.mlir. The name is referenced by the `check`
# test_suite in this file, so renaming it requires updating that list too.
iree_check_single_backend_test_suite(
    name = "check_large_linalg_matmul_cuda",
    srcs = CUDA_SRCS,
    driver = "cuda",
    tags = [
        # CUDA cuInit fails with sanitizer on.
        "noasan",
        "nomsan",
        "notsan",
        "noubsan",
        "requires-gpu-nvidia",
    ],
    target_backend = "cuda",
)

# Sources for the `check_rocm_hip` suite. This is the only glob-enforced list
# in this file that includes argmax.mlir.
ROCM_SRCS = enforce_glob(
    # keep sorted
    [
        "argmax.mlir",
        "gather_like_ops.mlir",
        "pack_i8.mlir",
        "softmax.mlir",
        "unpack.mlir",
    ],
    include = ["*.mlir"],
    exclude = [
        "conv2d.mlir",
        "fp_to_subbyte.mlir",
        "fp4_f32_conversion.mlir",
        "index.mlir",
        "large_linalg_matmul.mlir",
        "narrow_n_matmuls.mlir",
        "subbyte_to_fp.mlir",
        # Excluded due to https://github.com/llvm/llvm-project/issues/131386;
        # see bug #20294.
        "pack.mlir",
        "pack_dynamic_inner_tiles.mlir",
    ],
)

# Check-test suite over ROCM_SRCS using the `rocm` target backend and the
# `hip` driver.
iree_check_single_backend_test_suite(
    name = "check_rocm_hip",
    srcs = ROCM_SRCS,
    driver = "hip",
    target_backend = "rocm",
)

# index.mlir is excluded from every glob-enforced source list in this file and
# runs only in the dedicated llvm-cpu suite below.
INDEX_SRCS = [
    "index.mlir",
]

# llvm-cpu / local-task suite for the index test, targeting a generic CPU.
iree_check_single_backend_test_suite(
    name = "check_index_llvm-cpu_local-task",
    srcs = INDEX_SRCS,
    compiler_flags = ["--iree-llvmcpu-target-cpu=generic"],
    driver = "local-task",
    tags = [
        # Indexing math generates illegal instructions for riscv.
        "noriscv",
    ],
    target_backend = "llvm-cpu",
)

# argmax.mlir is excluded from LLVM_SRCS and gets its own llvm-cpu suite below.
ARGMAX_SRCS = [
    "argmax.mlir",
]

# llvm-cpu / local-task suite for the argmax test, targeting a generic CPU.
iree_check_single_backend_test_suite(
    name = "check_argmax_llvm-cpu_local-task",
    srcs = ARGMAX_SRCS,
    compiler_flags = ["--iree-llvmcpu-target-cpu=generic"],
    driver = "local-task",
    target_backend = "llvm-cpu",
)

# Aggregate suite covering every check-test suite declared in this file.
# Keep this list in sync (and sorted) when adding or removing a suite above.
test_suite(
    name = "check",
    tests = [
        ":check_argmax_llvm-cpu_local-task",
        ":check_index_llvm-cpu_local-task",
        ":check_large_linalg_matmul_cuda",
        ":check_llvm-cpu_dt_fusion_local-task",
        ":check_llvm-cpu_local-task",
        ":check_rocm_hip",
        ":check_vmvx_local-task",
        ":check_vmvx_ukernel_local-task",
        ":check_vulkan-spirv_vulkan",
        ":check_winograd_llvm-cpu_local-task",
        ":check_winograd_vulkan-spirv_vulkan",
    ],
)
